#!/usr/bin/env python

# reference: c041828_ISO_IEC_14496-12_2005(E).pdf

##################################################
# reader and writer
##################################################

import struct
from cStringIO import StringIO

def skip(stream, n):
	"""Move `stream` forward by `n` bytes.

	The new position is computed from the current one (tell) and applied
	with an absolute seek.
	"""
	target = stream.tell() + n
	stream.seek(target)

def skip_zeros(stream, n):
	"""Consume `n` bytes from `stream` and check that they are all zero.

	The read is done outside the assert on purpose: assert statements are
	removed when Python runs with -O, and the bytes must still be consumed
	so that the caller keeps parsing from the right offset.

	Raises AssertionError if the bytes read are not exactly `n` NUL bytes.
	"""
	data = stream.read(n)
	assert data == '\x00' * n

def read_int(stream):
	"""Read 4 bytes from `stream` and decode them as a big-endian signed int."""
	(number,) = struct.unpack('>i', stream.read(4))
	return number

def read_uint(stream):
	"""Read 4 bytes from `stream` and decode them as a big-endian unsigned int."""
	(number,) = struct.unpack('>I', stream.read(4))
	return number

def write_uint(stream, n):
	"""Write `n` to `stream` as a 4-byte big-endian unsigned int."""
	packed = struct.pack('>I', n)
	stream.write(packed)

def read_ushort(stream):
	"""Read 2 bytes from `stream` and decode them as a big-endian unsigned short."""
	(number,) = struct.unpack('>H', stream.read(2))
	return number

def read_ulong(stream):
	"""Read 8 bytes from `stream` and decode them as a big-endian unsigned 64-bit int."""
	(number,) = struct.unpack('>Q', stream.read(8))
	return number

def read_byte(stream):
	"""Read a single byte from `stream` and return its ordinal value."""
	char = stream.read(1)
	return ord(char)

def copy_stream(source, target, n):
	"""Copy exactly `n` bytes from `source` to `target`.

	The data is moved in pieces of at most 1 MiB so that large payloads are
	never held in memory at once. An AssertionError is raised when `source`
	returns fewer bytes than requested.
	"""
	chunk_limit = 1024*1024
	remaining = n
	while remaining > 0:
		wanted = min(chunk_limit, remaining)
		chunk = source.read(wanted)
		assert len(chunk) == wanted, 'no enough data'
		target.write(chunk)
		remaining -= wanted

class Atom:
	"""A generic MP4 atom: 4-character type, declared size and a body.

	For a plain Atom that is going to be written, `body` is the raw payload
	string and `size` equals the payload length plus the 8-byte header
	(write() asserts both). Subclasses store other kinds of bodies and
	override write()/calsize() accordingly.
	"""
	def __init__(self, type, size, body):
		assert len(type) == 4
		self.type, self.size, self.body = type, size, body
	def __str__(self):
		# the body is deliberately left out of the textual form
		return '<Atom(%s):%s>' % (self.type, '')
	def __repr__(self):
		return str(self)
	def write1(self, stream):
		"""Write the 8-byte header: 32-bit size followed by the type."""
		write_uint(stream, self.size)
		stream.write(self.type)
	def write(self, stream):
		"""Write header and raw body; only valid when the body is a str."""
		body = self.body
		assert type(body) == str, '%s: %s' % (self.type, type(body))
		assert self.size == 8 + len(body)
		self.write1(stream)
		stream.write(body)
	def calsize(self):
		"""Return the atom size; a plain atom never changes size."""
		return self.size

class CompositeAtom(Atom):
	"""An atom whose body is a list of child atoms."""
	def __init__(self, type, size, body):
		assert isinstance(body, list)
		Atom.__init__(self, type, size, body)
	def write(self, stream):
		"""Write the header, then every child in order."""
		assert type(self.body) == list
		self.write1(stream)
		for child in self.body:
			child.write(stream)
	def calsize(self):
		"""Recompute, store and return the size: 8-byte header plus all children."""
		total = 8
		for child in self.body:
			total += child.calsize()
		self.size = total
		return total
	def get1(self, k):
		"""Return the first direct child of type `k`; raise Exception if none."""
		for child in self.body:
			if child.type == k:
				return child
		raise Exception('atom not found: '+k)
	def get(self, *keys):
		"""Descend through nested children, one type per key, and return the last one."""
		current = self
		for key in keys:
			current = current.get1(key)
		return current
	def get_all(self, k):
		"""Return the list of all direct children of type `k`."""
		return [child for child in self.body if child.type == k]

class VariableAtom(Atom):
	"""An atom kept as raw bytes in which some 32-bit fields can be patched.

	`variables` is a list of (name, offset, value) tuples. On write, the
	4 bytes at `offset` inside the raw body are replaced by `value`;
	everything else is emitted unchanged. write() walks the list in order,
	so offsets are expected to be ascending.
	"""
	def __init__(self, type, size, body, variables):
		assert isinstance(body, str)
		Atom.__init__(self, type, size, body)
		self.variables = variables
	def write(self, stream):
		self.write1(stream)
		body = self.body
		cursor = 0   # position in body already emitted
		written = 0  # running count, checked against len(body) at the end
		for _name, offset, value in self.variables:
			stream.write(body[cursor:offset])
			write_uint(stream, value)
			written += (offset - cursor) + 4
			cursor = offset + 4
		stream.write(body[cursor:])
		written += len(body) - cursor
		assert written == len(body)
	def get(self, k):
		"""Return the current value of field `k`; raise Exception if unknown."""
		for entry in self.variables:
			if entry[0] == k:
				return entry[2]
		raise Exception('field not found: '+k)
	def set(self, k, v):
		"""Replace the value of field `k` with `v`; raise Exception if unknown."""
		for index, entry in enumerate(self.variables):
			if entry[0] == k:
				self.variables[index] = (k, entry[1], v)
				return
		raise Exception('field not found: '+k)

def read_raw(stream, size, left, type):
	"""Read an atom whose payload is kept as opaque bytes.

	`left` is the payload length; the header must be the plain 8-byte form
	(asserted). Returns a generic Atom holding the raw payload.
	"""
	assert size == left + 8
	payload = stream.read(left)
	return Atom(type, size, payload)

def read_body_stream(stream, left):
	"""Read `left` bytes and return them twice: raw and wrapped in a StringIO.

	The raw string lets callers keep the payload verbatim, while the
	in-memory stream lets them parse individual fields out of it. An
	AssertionError is raised when fewer than `left` bytes are available.
	"""
	payload = stream.read(left)
	assert len(payload) == left
	return payload, StringIO(payload)

def read_full_atom(stream):
	"""Read the 32-bit version/flags word that starts a "full" atom.

	The top 8 bits are the version (only version 0 is accepted); the low
	24 bits are the flags. The whole word is returned unchanged so that it
	can be written back verbatim.
	"""
	word = read_uint(stream)
	assert (word >> 24) == 0
	return word

def read_mvhd(stream, size, left, type):
	"""Parse an 'mvhd' atom, keeping its raw bytes and exposing `duration`.

	The payload must be exactly 100 bytes (version 0 layout). Only the
	duration is extracted; the other fields are read to validate the layout.
	Returns a VariableAtom whose 'duration' field lives at body offset 16.
	"""
	body, stream = read_body_stream(stream, left)
	read_full_atom(stream)
	remaining = left - 4

	# creation_time, modification_time, time_scale
	# (original note on dates: new Date(movieTime * 1000 - 2082850791998L))
	for _ in range(3):
		read_uint(stream)
	duration = read_uint(stream)
	remaining -= 16

	read_uint(stream)    # preferred rate
	read_ushort(stream)  # preferred volume
	reserved = stream.read(10)
	assert reserved == '\x00' * 10
	# 9 matrix entries, preview time/duration, poster time,
	# selection time/duration, current time, next track ID
	for _ in range(16):
		read_uint(stream)
	remaining -= 80
	assert remaining == 0
	return VariableAtom('mvhd', size, body, [('duration', 16, duration)])

def read_tkhd(stream, size, left, type):
	"""Parse a 'tkhd' atom, keeping its raw bytes and exposing `duration`.

	The payload must be exactly 84 bytes (version 0 layout). Only the
	duration is extracted; the other fields are read to validate the layout.
	Returns a VariableAtom whose 'duration' field lives at body offset 20.

	Reserved fields are read outside of the assert statements: asserts are
	removed under `python -O`, and a skipped read before `duration` would
	make the duration come from the wrong offset.
	"""
	body, stream = read_body_stream(stream, left)

	def expect_zeros(count):
		# always consume the bytes; only the check itself is an assert
		data = stream.read(count)
		assert data == '\x00' * count

	read_full_atom(stream)
	remaining = left - 4

	# creation_time, modification_time, track_id
	# (original note on dates: new Date(movieTime * 1000 - 2082850791998L))
	for _ in range(3):
		read_uint(stream)
	expect_zeros(4)
	duration = read_uint(stream)
	remaining -= 20

	expect_zeros(8)
	# layer, alternate group, volume
	for _ in range(3):
		read_ushort(stream)
	expect_zeros(2)
	# 9 matrix entries, then track width and height
	# (the original shifted width/height right by 16 to get the integer part)
	for _ in range(11):
		read_uint(stream)
	remaining -= 60
	assert remaining == 0
	return VariableAtom('tkhd', size, body, [('duration', 20, duration)])

def read_mdhd(stream, size, left, type):
	"""Parse an 'mdhd' atom, keeping its raw bytes and exposing `duration`.

	The payload must be exactly 24 bytes (version 0 layout). Returns a
	VariableAtom whose 'duration' field lives at body offset 16.
	"""
	body, stream = read_body_stream(stream, left)
	read_full_atom(stream)
	remaining = left - 4

	# creation_time, modification_time, time_scale
	# (original note on dates: new Date(movieTime * 1000 - 2082850791998L))
	for _ in range(3):
		read_uint(stream)
	duration = read_uint(stream)
	remaining -= 16

	# packed language code, quality
	for _ in range(2):
		read_ushort(stream)
	remaining -= 4

	assert remaining == 0
	return VariableAtom('mdhd', size, body, [('duration', 16, duration)])

def read_hdlr(stream, size, left, type):
	"""Parse an 'hdlr' atom for validation and return it as a raw Atom.

	Nothing from the payload is kept besides the raw bytes; the fields are
	only read to check that the name is terminated by a NUL byte.
	"""
	body, stream = read_body_stream(stream, left)
	read_full_atom(stream)

	# component type, handler type, manufacturer, flags, flags mask
	for _ in range(5):
		read_uint(stream)

	# everything after the 4 + 20 header bytes, minus the terminator, is the name
	name_length = left - 4 - 20 - 1
	stream.read(name_length)
	terminator = stream.read(1)
	assert terminator == '\x00'

	return Atom('hdlr', size, body)

def read_vmhd(stream, size, left, type):
	"""Parse a 'vmhd' atom for validation and return it as a raw Atom.

	After the version/flags word the payload must hold exactly 8 bytes.
	"""
	body, stream = read_body_stream(stream, left)
	read_full_atom(stream)

	assert left - 4 == 8
	# graphic mode, then the three op-color components
	for _ in range(4):
		read_ushort(stream)

	return Atom('vmhd', size, body)

def read_stsd(stream, size, left, type):
	"""Parse an 'stsd' atom: version/flags, entry count, then child atoms.

	Returns an atom whose body is (version_flags_word, children) and that
	knows how to serialise itself again.
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)
	remaining = left - 8

	children = []
	for _ in range(entry_count):
		child = read_atom(stream)
		children.append(child)
		remaining -= child.size

	assert remaining == 0

	class stsd_atom(Atom):
		"""'stsd' atom with body (version_flags_word, child atoms)."""
		def write(self, stream):
			self.write1(stream)
			entries = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(entries))
			for entry in entries:
				entry.write(stream)
		def calsize(self):
			oldsize = self.size # TODO: remove
			# header (8) + version/flags (4) + entry count (4) + children
			self.size = 8 + 4 + 4 + sum(entry.calsize() for entry in self.body[1])
			assert oldsize == self.size, '%s: %d, %d' % (self.type, oldsize, self.size) # TODO: remove
			return self.size
	return stsd_atom('stsd', size, (value, children))

def read_avc1(stream, size, left, type):
	"""Parse an 'avc1' sample entry for validation and return it as a raw Atom.

	The 78-byte fixed part is checked field by field, then the first child
	atom is parsed and must be 'avcC' or 'pasp'. Whatever follows is
	skipped; the returned Atom carries the complete raw payload.

	Reserved fields are read outside of the assert statements: asserts are
	removed under `python -O`, and a skipped read would make the child atom
	be parsed from the wrong offset.
	"""
	body, stream = read_body_stream(stream, left)

	def expect(expected):
		# always consume the bytes; only the check itself is an assert
		data = stream.read(len(expected))
		assert data == expected

	expect('\x00' * 6)
	read_ushort(stream)  # data reference index
	expect('\x00' * 2)
	expect('\x00' * 2)
	expect('\x00' * 12)
	read_ushort(stream)  # width
	read_ushort(stream)  # height
	read_uint(stream)    # horizontal resolution (the original shifted it right by 16)
	read_uint(stream)    # vertical resolution (the original shifted it right by 16)
	expect('\x00' * 4)
	read_ushort(stream)  # frame count
	read_byte(stream)    # length of the compressor name
	stream.read(31)      # compressor name
	read_ushort(stream)  # depth
	expect('\xff\xff')
	left -= 78

	child = read_atom(stream)
	assert child.type in ('avcC', 'pasp'), 'if the sub atom is not avcC or pasp (actual %s), you should not cache raw body' % child.type
	left -= child.size
	stream.read(left) # XXX: anything after the first child is skipped unparsed
	return Atom('avc1', size, body)

def read_avcC(stream, size, left, type):
	"""Skip over an 'avcC' payload and return a body-less Atom.

	The returned atom has a None body, so Atom.write() would reject it;
	only its type and size are meaningful to the caller.
	"""
	stream.read(left)  # payload is consumed but not kept
	placeholder = Atom('avcC', size, None)
	return placeholder

def read_stts(stream, size, left, type):
	"""Parse an 'stts' atom into (sample_count, sample_duration) pairs.

	Only tables with exactly one entry are accepted (asserted). Returns an
	atom whose body is (version_flags_word, entries).
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)
	assert entry_count == 1

	# each entry: sample_count, then sample_duration
	samples = [(read_uint(stream), read_uint(stream)) for _ in range(entry_count)]

	# 4 (version/flags) + 4 (entry count) + 8 per entry must use up the payload
	assert left - 8 - 8 * len(samples) == 0

	class stts_atom(Atom):
		"""'stts' atom with body (version_flags_word, [(count, duration), ...])."""
		def write(self, stream):
			self.write1(stream)
			entries = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(entries))
			for count, duration in entries:
				write_uint(stream, count)
				write_uint(stream, duration)
		def calsize(self):
			oldsize = self.size # TODO: remove
			self.size = 8 + 4 + 4 + len(self.body[1]) * 8
			assert oldsize == self.size, '%s: %d, %d' % (self.type, oldsize, self.size) # TODO: remove
			return self.size
	return stts_atom('stts', size, (value, samples))

def read_stss(stream, size, left, type):
	"""Parse an 'stss' atom into a list of 32-bit sample numbers.

	Returns an atom whose body is (version_flags_word, sample_numbers).
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)

	samples = [read_uint(stream) for _ in range(entry_count)]

	# 4 (version/flags) + 4 (entry count) + 4 per entry must use up the payload
	assert left - 8 - 4 * len(samples) == 0

	class stss_atom(Atom):
		"""'stss' atom with body (version_flags_word, [sample_number, ...])."""
		def write(self, stream):
			self.write1(stream)
			numbers = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(numbers))
			for number in numbers:
				write_uint(stream, number)
		def calsize(self):
			self.size = 8 + 4 + 4 + len(self.body[1]) * 4
			return self.size
	return stss_atom('stss', size, (value, samples))

def read_stsc(stream, size, left, type):
	"""Parse an 'stsc' atom into (first_chunk, samples_per_chunk, description_index) triples.

	Every entry must use sample description index 1 (asserted). Returns an
	atom whose body is (version_flags_word, entries).
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)

	chunks = []
	for _ in range(entry_count):
		# first_chunk, samples_per_chunk, sample_description_index
		entry = (read_uint(stream), read_uint(stream), read_uint(stream))
		# presumably the index of the matching 'stsd' entry - TODO confirm
		assert entry[2] == 1
		chunks.append(entry)

	# 4 (version/flags) + 4 (entry count) + 12 per entry must use up the payload
	assert left - 8 - 12 * len(chunks) == 0

	class stsc_atom(Atom):
		"""'stsc' atom with body (version_flags_word, [(first, per_chunk, index), ...])."""
		def write(self, stream):
			self.write1(stream)
			entries = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(entries))
			for entry in entries:
				first_chunk, samples_per_chunk, sample_description_index = entry
				for field in (first_chunk, samples_per_chunk, sample_description_index):
					write_uint(stream, field)
		def calsize(self):
			self.size = 8 + 4 + 4 + len(self.body[1]) * 12
			return self.size
	return stsc_atom('stsc', size, (value, chunks))

def read_stsz(stream, size, left, type):
	"""Parse an 'stsz' atom into a list of per-sample sizes.

	Only the per-sample form (sample_size == 0) is accepted (asserted).
	Returns an atom whose body is
	(version_flags_word, sample_size, sample_count, sizes).
	"""
	value = read_full_atom(stream)
	sample_size = read_uint(stream)
	sample_count = read_uint(stream)

	assert sample_size == 0
	sizes = []
	if sample_size == 0:
		sizes = [read_uint(stream) for _ in range(sample_count)]

	# 4 (version/flags) + 8 (size, count) + 4 per entry must use up the payload
	assert left - 12 - 4 * len(sizes) == 0

	class stsz_atom(Atom):
		"""'stsz' atom with body (version_flags_word, sample_size, sample_count, sizes)."""
		def write(self, stream):
			self.write1(stream)
			for field in self.body[:3]:
				write_uint(stream, field)
			for entry_size in self.body[3]:
				write_uint(stream, entry_size)
		def calsize(self):
			self.size = 8 + 4 + 8 + len(self.body[3]) * 4
			return self.size
	return stsz_atom('stsz', size, (value, sample_size, sample_count, sizes))

def read_stco(stream, size, left, type):
	"""Parse an 'stco' atom into a list of 32-bit chunk offsets.

	Returns an atom whose body is (version_flags_word, offsets).
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)

	offsets = [read_uint(stream) for _ in range(entry_count)]

	# 4 (version/flags) + 4 (entry count) + 4 per entry must use up the payload
	assert left - 8 - 4 * len(offsets) == 0

	class stco_atom(Atom):
		"""'stco' atom with body (version_flags_word, [chunk_offset, ...])."""
		def write(self, stream):
			self.write1(stream)
			chunk_offsets = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(chunk_offsets))
			for chunk_offset in chunk_offsets:
				write_uint(stream, chunk_offset)
		def calsize(self):
			self.size = 8 + 4 + 4 + len(self.body[1]) * 4
			return self.size
	return stco_atom('stco', size, (value, offsets))

def read_ctts(stream, size, left, type):
	"""Parse a 'ctts' atom into (sample_count, sample_offset) pairs.

	Returns an atom whose body is (version_flags_word, entries).
	"""
	value = read_full_atom(stream)
	entry_count = read_uint(stream)

	# each entry: sample_count, then sample_offset
	samples = [(read_uint(stream), read_uint(stream)) for _ in range(entry_count)]

	# 4 (version/flags) + 4 (entry count) + 8 per entry must use up the payload
	assert left - 8 - 8 * len(samples) == 0

	class ctts_atom(Atom):
		"""'ctts' atom with body (version_flags_word, [(count, offset), ...])."""
		def write(self, stream):
			self.write1(stream)
			entries = self.body[1]
			write_uint(stream, self.body[0])
			write_uint(stream, len(entries))
			for count, offset in entries:
				write_uint(stream, count)
				write_uint(stream, offset)
		def calsize(self):
			self.size = 8 + 4 + 4 + len(self.body[1]) * 8
			return self.size
	return ctts_atom('ctts', size, (value, samples))

def read_smhd(stream, size, left, type):
	"""Parse an 'smhd' atom for validation and return it as a raw Atom.

	The payload must be exactly 8 bytes: version/flags, a 16-bit balance
	and two reserved zero bytes.
	"""
	body, stream = read_body_stream(stream, left)
	read_full_atom(stream)

	read_ushort(stream)  # balance
	reserved = stream.read(2)
	assert reserved == '\x00\x00'

	assert left - 4 - 4 == 0
	return Atom('smhd', size, body)

def read_mp4a(stream, size, left, type):
	"""Parse an 'mp4a' sample entry for validation and return it as a raw Atom.

	The 28-byte fixed part is checked field by field and must be followed
	by exactly one 'esds' child that uses up the rest of the payload.

	Reserved fields are read outside of the assert statements: asserts are
	removed under `python -O`, and a skipped read would make the child atom
	be parsed from the wrong offset.
	"""
	body, stream = read_body_stream(stream, left)

	def expect_zeros(count):
		# always consume the bytes; only the check itself is an assert
		data = stream.read(count)
		assert data == '\x00' * count

	expect_zeros(6)
	read_ushort(stream)  # data reference index
	expect_zeros(8)
	read_ushort(stream)  # channel count
	read_ushort(stream)  # sample size
	expect_zeros(4)
	read_ushort(stream)  # time scale
	expect_zeros(2)
	left -= 28

	atom = read_atom(stream)
	assert atom.type == 'esds'
	left -= atom.size

	assert left == 0
	return Atom('mp4a', size, body)

def read_descriptor(stream):
	"""Placeholder for descriptor parsing.

	Consumes the one-byte descriptor tag and then always raises
	NotImplementedError.
	"""
	read_byte(stream)  # descriptor tag, not interpreted yet
	raise NotImplementedError()

def read_esds(stream, size, left, type):
	"""Check the version of an 'esds' atom, skip its payload, return a body-less Atom.

	The returned atom has a None body, so Atom.write() would reject it;
	only its type and size are meaningful to the caller.
	"""
	word = read_uint(stream)
	# top 8 bits are the version, the low 24 bits (flags) are ignored
	assert (word >> 24) == 0

	stream.read(left - 4)  # rest of the payload is consumed but not kept
	return Atom('esds', size, None)

def read_composite_atom(stream, size, left, type):
	"""Read child atoms until `left` payload bytes are used up.

	The children must add up to exactly `left` bytes (asserted). Returns a
	CompositeAtom holding them in file order.
	"""
	remaining = left
	children = []
	while remaining > 0:
		child = read_atom(stream)
		children.append(child)
		remaining -= child.size
	assert remaining == 0, remaining
	return CompositeAtom(type, size, children)

def read_mdat(stream, size, left, type):
	"""Record where an 'mdat' payload lives instead of loading it.

	The payload is skipped; the returned atom's body is
	(source_stream, payload_start, payload_length) and the data is copied
	from the source stream only when the atom is written. The source stream
	therefore has to stay open until then.
	"""
	payload_start = stream.tell()
	payload_length = left
	skip(stream, payload_length)

	class mdat_atom(Atom):
		"""'mdat' atom with body (source stream, payload start, payload length)."""
		def write(self, stream):
			self.write1(stream)
			self.write2(stream)
		def write2(self, stream):
			"""Copy only the payload (no header) from the source stream."""
			source, start, length = self.body
			# note: the source position is moved and not restored afterwards
			source.seek(start)
			copy_stream(source, stream, length)
		def calsize(self):
			return self.size
	return mdat_atom('mdat', size, (stream, payload_start, payload_length))

# Maps a 4-character atom type to the function that parses its payload.
# read_atom() raises NotImplementedError for any type missing from this table.
atom_readers = {
	# containers, parsed recursively into child atoms
	'moov': read_composite_atom,
	'trak': read_composite_atom,
	'mdia': read_composite_atom,
	'minf': read_composite_atom,
	'dinf': read_composite_atom,
	'stbl': read_composite_atom,

	# headers whose duration is merged
	'mvhd': read_mvhd,
	'tkhd': read_tkhd,
	'mdhd': read_mdhd,

	# parsed for validation only, nothing to merge
	'hdlr': read_hdlr,
	'vmhd': read_vmhd,
	'smhd': read_smhd,
	'stsd': read_stsd,
	'avc1': read_avc1,
	'avcC': read_avcC,
	'mp4a': read_mp4a,
	'esds': read_esds,

	# sample tables that are merged
	'stts': read_stts, # sample_count, sample_duration
	'stss': read_stss, # join indexes
	'stsc': read_stsc, # sample numbers
	'stsz': read_stsz, # samples
	'stco': read_stco, # chunk offsets
	'ctts': read_ctts,

	# kept as opaque raw bytes
	'ftyp': read_raw,
	'yqoo': read_raw,
	'iods': read_raw,
	'dref': read_raw,
	'free': read_raw,
	'edts': read_raw,
	'pasp': read_raw,

	# media data, copied lazily from the source file
	'mdat': read_mdat,
}
#stsd sample descriptions (codec types, initialization etc.) 
#stts (decoding) time-to-sample  
#ctts (composition) time to sample 
#stsc sample-to-chunk, partial data-offset information 
#stsz sample sizes (framing) 
#stz2 compact sample sizes (framing) 
#stco chunk offset, partial data-offset information 
#co64 64-bit chunk offset 
#stss sync sample table (random access points) 
#stsh shadow sync sample table 
#padb sample padding bits 
#stdp sample degradation priority 
#sdtp independent and disposable samples 
#sbgp sample-to-group 
#sgpd sample group description 
#subs sub-sample information


def read_atom(stream):
	"""Read the next atom from `stream` using the reader registered for its type.

	Returns None when the stream is exhausted. A size field of 1 means that
	the real size follows as a 64-bit integer. 'uuid' atoms and a size of 0
	are rejected by assertions; a type without a registered reader raises
	NotImplementedError.
	"""
	header = stream.read(8)
	if not header:
		return None
	assert len(header) == 8
	header_length = 8
	size = struct.unpack('>I', header[:4])[0]
	assert size > 0
	type = header[4:8]
	assert type != 'uuid'
	if size == 1:
		# extended header: 64-bit size stored right after the type
		size = read_ulong(stream)
		header_length += 8

	left = size - header_length
	reader = atom_readers.get(type)
	if reader is None:
		raise NotImplementedError('%s: %d' % (type, left))
	return reader(stream, size, left, type)

def write_atom(stream, atom):
	"""Serialise `atom` into `stream` by delegating to the atom's own write()."""
	atom.write(stream)

def parse_atoms(stream):
	"""Read atoms from `stream` until read_atom() reports the end, and return them as a list."""
	atoms = []
	atom = read_atom(stream)
	while atom:
		atoms.append(atom)
		atom = read_atom(stream)
	return atoms

def read_mp4(stream):
	"""Parse all top-level atoms of an MP4 stream.

	Returns (atoms, moov, mdat). The file must contain exactly one 'moov'
	and exactly one 'mdat' atom at top level (asserted).
	"""
	atoms = parse_atoms(stream)
	moovs = [atom for atom in atoms if atom.type == 'moov']
	mdats = [atom for atom in atoms if atom.type == 'mdat']
	assert len(moovs) == 1
	assert len(mdats) == 1
	return atoms, moovs[0], mdats[0]

##################################################
# merge
##################################################

def merge_stts(samples_list):
	"""Merge single-entry 'stts' tables into one single-entry table.

	Every input table must hold exactly one (sample_count, sample_duration)
	entry and all durations must be equal; the counts are added up.
	"""
	entries = []
	for table in samples_list:
		assert len(table) == 1
		entries.append(table[0])
	counts, durations = zip(*entries)
	distinct_durations = set(durations)
	assert len(distinct_durations) == 1, 'not all durations equal'
	total_count = sum(counts)
	return [(total_count, durations[0])]

def merge_stss(samples, sample_number_list):
	"""Concatenate sample-number tables, renumbering each one.

	`samples` yields one list of sample numbers per file and
	`sample_number_list` yields the total number of samples of the same
	file. Numbers of later files are shifted by the sample totals of all
	files before them.
	"""
	merged = []
	base = 0
	for numbers, sample_total in zip(samples, sample_number_list):
		merged.extend(base + number for number in numbers)
		base += sample_total
	return merged

def merge_stsc(chunks_list, total_chunk_number_list):
	"""Concatenate 'stsc' tables, renumbering the first_chunk column.

	`chunks_list` yields one table of
	(first_chunk, samples_per_chunk, description_index) per file and
	`total_chunk_number_list` yields the number of chunks of the same file.
	Each entry covers the chunks up to the next entry's first_chunk (or up
	to the file's last chunk), and the merged numbering continues across
	files starting from 1.
	"""
	merged = []
	next_first_chunk = 1
	for chunks, total in zip(chunks_list, total_chunk_number_list):
		# where each entry's run ends: the next entry's first chunk,
		# or one past the last chunk for the final entry
		run_ends = [entry[0] for entry in chunks[1:]] + [total + 1]
		for entry, run_end in zip(chunks, run_ends):
			merged.append((next_first_chunk, entry[1], entry[2]))
			next_first_chunk += run_end - entry[0]
	return merged

def merge_stco(offsets_list, mdats):
	"""Concatenate chunk-offset tables, rebased onto the merged media data.

	Each offset is made relative to the start of its own file's 'mdat'
	payload (mdat.body[1]) and then shifted by the payload lengths
	(mdat.size - 8) of all files before it. The result is therefore
	relative to the start of the merged payload.
	"""
	merged = []
	base = 0
	for offsets, mdat in zip(offsets_list, mdats):
		payload_start = mdat.body[1]
		for chunk_offset in offsets:
			merged.append(base + chunk_offset - payload_start)
		base += mdat.size - 8
	return merged

def merge_stsz(sizes_list):
	"""Concatenate the per-file sample-size lists into one new list, in order."""
	return [size for sizes in sizes_list for size in sizes]

def merge_mdats(mdats):
	"""Build one 'mdat' atom that writes the payloads of all `mdats` back to back.

	The declared size is one 8-byte header plus the payload length
	(size - 8) of every input atom. The inputs are not copied; their
	write2() methods are called when the merged atom is written.
	"""
	payload_total = sum(mdat.size - 8 for mdat in mdats)

	class multi_mdat_atom(Atom):
		"""'mdat' atom whose body is the list of source mdat atoms."""
		def write(self, stream):
			self.write1(stream)
			self.write2(stream)
		def write2(self, stream):
			for part in self.body:
				part.write2(stream)
		def calsize(self):
			return self.size
	return multi_mdat_atom('mdat', payload_total + 8, mdats)

def merge_moov(moovs, mdats):
	"""Merge the 'moov' atoms of several files into the first one.

	`moovs` and `mdats` are the per-file 'moov' and 'mdat' atoms, in the
	order the files are joined. Every moov must have exactly two 'trak'
	children (asserted). Durations are added up and the sample tables
	(stts, stsc, stco, stsz for both tracks; stss and ctts for the first
	track only) are concatenated. moovs[0] is modified in place and
	returned.

	Chunk offsets are first rebased onto the merged media data and then
	shifted by where that data will start in the output.
	NOTE(review): the shift is the first file's mdat payload start plus the
	growth of moov, which looks like it assumes moov is written before
	mdat - confirm.
	"""
	def table(trak, name):
		# sample table atom `name` of a track
		return trak.get('mdia', 'minf', 'stbl', name)

	def replace_entries(trak, name, entries):
		# keep the atom's own version/flags word, swap the entry list
		atom = table(trak, name)
		atom.body = atom.body[0], entries

	def replace_sizes(trak, sizes):
		# stsz additionally carries sample_size and an explicit sample count
		atom = table(trak, 'stsz')
		atom.body = atom.body[0], atom.body[1], len(sizes), sizes

	# Gather everything from the inputs before touching moovs[0],
	# which is both an input and the output.
	mvhd_duration = 0
	tkhd_durations = [0, 0]
	mdhd_durations = [0, 0]
	trak0s = []
	trak1s = []
	for source in moovs:
		mvhd_duration += source.get('mvhd').get('duration')
		traks = source.get_all('trak')
		assert len(traks) == 2
		trak0s.append(traks[0])
		trak1s.append(traks[1])
		for index in (0, 1):
			tkhd_durations[index] += traks[index].get('tkhd').get('duration')
			mdhd_durations[index] += traks[index].get('mdia', 'mdhd').get('duration')

	stts0 = merge_stts([table(t, 'stts').body[1] for t in trak0s])
	stts1 = merge_stts([table(t, 'stts').body[1] for t in trak1s])

	stss = merge_stss([table(t, 'stss').body[1] for t in trak0s], [len(table(t, 'stsz').body[3]) for t in trak0s])

	stsc0 = merge_stsc([table(t, 'stsc').body[1] for t in trak0s], [len(table(t, 'stco').body[1]) for t in trak0s])
	stsc1 = merge_stsc([table(t, 'stsc').body[1] for t in trak1s], [len(table(t, 'stco').body[1]) for t in trak1s])

	stco0 = merge_stco([table(t, 'stco').body[1] for t in trak0s], mdats)
	stco1 = merge_stco([table(t, 'stco').body[1] for t in trak1s], mdats)

	stsz0 = merge_stsz([table(t, 'stsz').body[3] for t in trak0s])
	stsz1 = merge_stsz([table(t, 'stsz').body[3] for t in trak1s])

	ctts = []
	for t in trak0s:
		ctts.extend(table(t, 'ctts').body[1])

	# Apply the merged values to the first moov.
	moov = moovs[0]
	trak0 = trak0s[0]
	trak1 = trak1s[0]

	moov.get('mvhd').set('duration', mvhd_duration)
	trak0.get('tkhd').set('duration', tkhd_durations[0])
	trak1.get('tkhd').set('duration', tkhd_durations[1])
	trak0.get('mdia', 'mdhd').set('duration', mdhd_durations[0])
	trak1.get('mdia', 'mdhd').set('duration', mdhd_durations[1])

	replace_entries(trak0, 'stts', stts0)
	replace_entries(trak1, 'stts', stts1)
	replace_entries(trak0, 'stss', stss)
	replace_entries(trak0, 'stsc', stsc0)
	replace_entries(trak1, 'stsc', stsc1)
	# The offsets set here are still relative to the merged payload; they
	# are only needed so that calsize() sees tables of the final length.
	replace_entries(trak0, 'stco', stco0)
	replace_entries(trak1, 'stco', stco1)
	replace_sizes(trak0, stsz0)
	replace_sizes(trak1, stsz1)
	replace_entries(trak0, 'ctts', ctts)

	# moov grew, so the media data starts later than in the first input.
	old_moov_size = moov.size
	new_moov_size = moov.calsize()
	new_mdat_start = mdats[0].body[1] + new_moov_size - old_moov_size
	replace_entries(trak0, 'stco', [offset + new_mdat_start for offset in stco0])
	replace_entries(trak1, 'stco', [offset + new_mdat_start for offset in stco1])

	return moov

def merge_mp4s(files, output):
	"""Join the MP4 files named in `files` into the file named `output`.

	The top-level atom layout of the first input is reproduced, with its
	'moov' and 'mdat' replaced by the merged ones; every other atom of the
	first input is written unchanged.

	The inputs stay open until the output is complete, because the media
	data is copied from them only while writing. They are always closed
	afterwards, also when parsing or writing fails.
	"""
	assert files
	ins = []
	try:
		for mp4 in files:
			ins.append(open(mp4, 'rb'))
		mp4s = [read_mp4(source) for source in ins]
		moovs = [parsed[1] for parsed in mp4s]
		mdats = [parsed[2] for parsed in mp4s]
		moov = merge_moov(moovs, mdats)
		mdat = merge_mdats(mdats)
		with open(output, 'wb') as target:
			for x in mp4s[0][0]:
				if x.type == 'moov':
					moov.write(target)
				elif x.type == 'mdat':
					mdat.write(target)
				else:
					x.write(target)
	finally:
		for source in ins:
			source.close()

##################################################
# main
##################################################

# TODO: FIXME: duplicate of flv_join

def guess_output(inputs):
	"""Derive an output file name from the common prefix of the input names.

	Only the base names are compared. The prefix is capped at one character
	less than the shortest name, and '.mp4' is appended to it. When the
	names share no leading character, 'output.mp4' is returned.

	The base names are collected into a list first, so the result does not
	depend on `map` returning a list and `inputs` may be any iterable.
	"""
	import os.path
	names = [os.path.basename(path) for path in inputs]
	shortest = min(len(name) for name in names)
	# commonprefix compares character by character, which is what is wanted here
	prefix = os.path.commonprefix(names)[:max(shortest - 1, 0)]
	if prefix:
		return prefix + '.mp4'
	return 'output.mp4'

def concat_mp4s(mp4s, output=None):
	"""Join the files in `mp4s` and return the path that was written.

	When `output` is empty, a name is guessed from the input names; when it
	is an existing directory, the guessed name is placed inside it.
	"""
	assert mp4s, 'no mp4 file found'
	import os.path
	if output and os.path.isdir(output):
		output = os.path.join(output, guess_output(mp4s))
	elif not output:
		output = guess_output(mp4s)

	joined_names = ', '.join(mp4s)
	print('Joining %s into %s' % (joined_names, output))
	merge_mp4s(mp4s, output)

	return output

def usage():
	"""Print the command-line synopsis to standard output."""
	print('python mp4_join.py --output target.mp4 mp4...')

def main():
	"""Command-line entry point: parse the options and join the given files.

	-h/--help prints the usage and exits with status 0. Invalid options or
	a missing file list print the usage and exit with status 1.
	"""
	import sys, getopt
	try:
		opts, args = getopt.getopt(sys.argv[1:], "ho:", ["help", "output="])
	except getopt.GetoptError:
		usage()
		sys.exit(1)

	output = None
	for option, argument in opts:
		if option in ("-h", "--help"):
			usage()
			sys.exit()
		if option in ("-o", "--output"):
			output = argument
			continue
		# any option getopt accepted but that is not handled above
		usage()
		sys.exit(1)

	if not args:
		usage()
		sys.exit(1)

	concat_mp4s(args, output)

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
	main()

